# Change the path to the location of Earwigs.csv on your computer
EW <- read_csv("../Data/Earwigs.csv", show_col_types = FALSE)

ggplot() +
  geom_point(data = EW, aes(Density, Proportion_forceps),
             color = "steelblue", size = 3)

Problem Set 4
In this problem set, before we get to the models, we want to give some demonstration code for two packages that help with some aspects of Bayesian inference. We’ll walk through code so you have templates to use for later analyses in this problem set.
priorsense Package
Those new to Bayesian analysis usually have many questions about how to set a prior and how to evaluate whether the chosen prior is appropriate. We will show you a new package that helps to assess the second question: is my prior appropriate? This is the priorsense package. If you are interested in a less technical introduction, here is a short video describing the package. We decided to leave this out of the lectures in part because it would have taken time away from other topics and because it is an area of active research.
The approach that Kallioinen et al. take is to use importance sampling (as in PSIS-LOO-CV) of the prior or likelihood raised to an exponent (the “power” in power-scaling) to detect instances of prior-data conflict, wherein the prior contains too much information (e.g., is too constrained), the likelihood (data) has too little information, or some combination of the two.
The approach is simply to give either the prior or the likelihood more (or less) power by raising it to an exponent \(\alpha\) that varies around 1 (i.e., no scaling). For example, in testing the prior \(Pr(\theta)\) sensitivity:
\[Pr(\theta|y) \sim Pr(y|\theta) Pr(\theta)^\alpha\]
the prior is raised to an exponent \(\alpha\) that can vary. The response of the posterior \(Pr(\theta|y)\) to changing the strength of the prior tells us how sensitive the model is to the prior. Note here that we are only looking at the numerator of Bayes’ Rule, because MCMC methods make dealing with the probability of the data \(Pr(y)\) unnecessary.
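To build intuition for what power-scaling does, here is a small base-R sketch. It does not use priorsense itself; the toy normal-mean model, the made-up data `y`, and the deliberately narrow prior are all assumptions for illustration only.

```r
# Grid approximation of Pr(theta | y) ~ Pr(y | theta) * Pr(theta)^alpha
theta <- seq(-3, 3, length.out = 601)
y <- c(1.2, 0.8, 1.5)                          # toy data (made up)
loglik <- sapply(theta, function(m) sum(dnorm(y, mean = m, sd = 1, log = TRUE)))
logprior <- dnorm(theta, mean = 0, sd = 0.25, log = TRUE)  # deliberately narrow

posterior_mean <- function(alpha) {
  logpost <- loglik + alpha * logprior   # prior raised to the power alpha
  w <- exp(logpost - max(logpost))       # unnormalized posterior on the grid
  sum(theta * w) / sum(w)                # posterior mean
}

# alpha < 1 weakens the prior (posterior mean drifts toward the data);
# alpha > 1 strengthens it (posterior mean is pulled toward the prior mean, 0)
sapply(c(0.5, 1, 2), posterior_mean)
```

If the posterior summary barely moves as \(\alpha\) varies around 1, the model is insensitive to the prior; large movement is exactly the sensitivity that priorsense detects (using importance sampling rather than a grid).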
The prior power-scaling approach is complementary to the prior predictive simulation that we have been using so far. The general approach would be to develop priors via prior predictive simulation (using the different options in ulam() or brm() to sample from the prior only) and then check that those priors are adequate using the functions in priorsense.
The main functions of this package are:
- powerscale_sequence() evaluates the prior/likelihood sensitivity across a range of powers. This function can be wrapped in either powerscale_plot_dens() or powerscale_plot_quantities() to plot changes in the posterior densities or the dependency of the posterior on prior or likelihood scaling, respectively.
- powerscale_sensitivity() is the main function to test the sensitivity of the prior and likelihood via power-scaling.
The package is not (yet) on CRAN, so you have to install it directly from github: remotes::install_github("n-kall/priorsense")
We will use the Earwigs data from Problem Set 3 to explore how to use this package.
Load the data and plot:
Check the variables that can have priors in the brm() model:
get_prior(Proportion_forceps ~ 1 + Density,
          data = EW)

                  prior     class    coef group resp dpar nlpar lb ub       source
                 (flat)         b                                          default
                 (flat)         b Density                             (vectorized)
 student_t(3, 0.3, 2.5) Intercept                                          default
   student_t(3, 0, 2.5)     sigma                                0         default
We will skip the prior predictive simulation in this example. To use the priorsense functions, you have to fit the model with the data (i.e., not sampling from the prior only). So in practice you would do the prior predictive simulation here, sampling from the prior to get a prospective set of priors. Then you would use the functions in priorsense to evaluate the priors with the model.
For now, we will set the priors to be very bad (which we know from doing Problem Set 3), to see what the diagnostics look like:
fm <- brm(Proportion_forceps ~ 1 + Density,
data = EW,
prior = c(prior(normal(0, 0.0001), class = b),
prior(normal(0, 0.0001), class = Intercept),
prior(normal(0, 0.01), class = sigma)),
          refresh = 0)

Start sampling
Running MCMC with 4 sequential chains...
Chain 1 finished in 0.1 seconds.
Chain 2 finished in 0.1 seconds.
Chain 3 finished in 0.1 seconds.
Chain 4 finished in 0.1 seconds.
All 4 chains finished successfully.
Mean chain execution time: 0.1 seconds.
Total execution time: 0.5 seconds.
Power-scale sensitivity visual diagnostics
powerscale_plot_dens() plots overlapping density plots color-coded by the \(\alpha\) power-scaling exponents in the tested range. When the lines overlap, the density estimate is not sensitive to power-scaling. Considering the prior, for example, if the lines differ, then the posterior density changes when the prior is scaled by a power (bad). So what we want to see is that the density lines for the prior are superimposed. For the likelihood, the lines should not overlap, indicating that the data are informative: scaling the likelihood changes the posterior.
Because the scales of the variables are so different, we will plot them separately. Notice the embedded powerscale_sequence(fm). We could just as well pre-compute this and pass it to the function (which would be faster if there were more data).
powerscale_plot_dens(
  powerscale_sequence(fm),
  quantities = c("mean", "sd"),
  variables = "b_Density")

powerscale_plot_dens(
  powerscale_sequence(fm),
  quantities = c("mean", "sd"),
  variables = "b_Intercept")

powerscale_plot_dens(
  powerscale_sequence(fm),
  quantities = c("mean", "sd"),
  variables = "sigma")

You can see how all of the priors are sensitive to scaling, particularly sigma. We also get messages about high Pareto \(k\) values, indicating poor fit.
powerscale_plot_quantities() visualizes the rate of change in the posterior as \(\alpha\) changes. Ideally we would like to see a flat-ish line for the prior, indicating that the prior is not sensitive to scaling. There are many options for the divergence measure, but the default “Cumulative Jensen-Shannon distance” (cjs_dist) seems to work fine.
powerscale_plot_quantities(
  powerscale_sequence(fm),
  quantities = c("mean", "sd"),
  variables = c("b_Density", "b_Intercept", "sigma")
)

Power-scale sensitivity table
Finally, the function powerscale_sensitivity() makes a table of sensitivity values. Values > 0.05 indicate sensitivity of the prior or likelihood. The last column provides a diagnosis.
powerscale_sensitivity(fm)

Sensitivity based on cjs_dist:
# A tibble: 4 × 4
variable prior likelihood diagnosis
<chr> <dbl> <dbl> <chr>
1 b_Intercept 0.103 0.0488 strong prior / weak likelihood
2 b_Density 0.102 0.0485 strong prior / weak likelihood
3 sigma 1.03 0.987 prior-data conflict
4 Intercept 0.0662 0.0468 strong prior / weak likelihood
Here we have a “weak likelihood” for the first two rows, because the priors on the means and Intercept are way too strong (Normal(0, 0.0001)). This means that the data are insufficient to move the likelihood away from the prior.
The prior for sigma is also poor, resulting in a prior-data conflict, where one goes up and the other goes down as \(\alpha\) changes.
Improved priors
Let’s use the priors that we developed for Problem Set 3 and hopefully see a better pattern.
fm <- brm(Proportion_forceps ~ 1 + Density,
data = EW,
prior = c(prior(normal(0, 0.1), class = b),
prior(normal(0, 1), class = Intercept),
prior(normal(0, 1), class = sigma)),
refresh = 0,
          iter = 5e3)

Start sampling
Running MCMC with 4 sequential chains...
Chain 1 finished in 0.1 seconds.
Chain 2 finished in 0.1 seconds.
Chain 3 finished in 0.1 seconds.
Chain 4 finished in 0.1 seconds.
All 4 chains finished successfully.
Mean chain execution time: 0.1 seconds.
Total execution time: 0.6 seconds.
Evaluating the priors:
powerscale_plot_dens(
  powerscale_sequence(fm),
  quantities = c("mean", "sd"),
  variables = "b_Density")

powerscale_plot_dens(
  powerscale_sequence(fm),
  quantities = c("mean", "sd"),
  variables = "b_Intercept")

powerscale_plot_dens(
  powerscale_sequence(fm),
  quantities = c("mean", "sd"),
  variables = "sigma")

powerscale_plot_quantities(
  powerscale_sequence(fm),
  quantities = c("mean", "sd"),
  variables = c("b_Density", "b_Intercept", "sigma")
)

powerscale_sensitivity(fm)

Sensitivity based on cjs_dist:
# A tibble: 4 × 4
variable prior likelihood diagnosis
<chr> <dbl> <dbl> <chr>
1 b_Intercept 0.000985 0.0887 -
2 b_Density 0.000219 0.0884 -
3 sigma 0.00136 0.226 -
4 Intercept 0.00169 0.0937 -
Notice how in the density plots the densities of the priors all overlap (lack of sensitivity) and the posteriors do not (the data are able to inform the posterior). The quantities plot shows relatively flat lines for the priors, and likelihoods that are sensitive to scaling. Finally, the table has all values < 0.05 for the prior.
tidybayes Package
We want to add one more set of analysis tools to our general Bayesian inference kit: the tidybayes package. tidybayes has a variety of functions for extracting parts of fit models (from lots of model fitting interfaces), augmenting model fits with various kinds of predicted values, and making some very impressive visualizations.
The documentation has a [page of visualizations from brms models](http://mjskay.github.io/tidybayes/articles/tidy-brms.html).
tidybayes is particularly useful for working with the posteriors of multilevel models, which is what the demo code that the documentation provides is based on. Our usage here will be a little more pedestrian, but we can still see how useful the package can be.
We will adapt some of the demo code to plot the posterior for the earwigs model we just fit.
When you are trying to figure out what the variable names are in a model, the function get_variables() returns them:
library(tidybayes)
get_variables(fm)

 [1] "b_Intercept"   "b_Density"     "sigma"         "Intercept"
[5] "lprior" "lp__" "accept_stat__" "treedepth__"
[9] "stepsize__" "divergent__" "n_leapfrog__" "energy__"
By default, brm() returns parameters prepended with b_ for “b” parameters (what are often called main or fixed effects) and r_ for random/multilevel effects (though we aren’t doing multilevel models in this module, it’s useful to know).
tidy_draws() is the simplest way to extract a posterior. You can see how it returns a lot of diagnostics as well: acceptance statistic, tree depth, step size, and whether that draw was a divergence or not.
tidybayes has a summary()-like function summarise_draws() (Commonwealth spelling only). We can pipe the output of tidy_draws() directly to it.
fm |> tidy_draws()

# A tibble: 10,000 × 15
.chain .iteration .draw b_Intercept b_Density sigma Intercept lprior lp__
<int> <int> <int> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
1 1 1 1 0.128 0.00454 0.193 0.244 0.190 6.36
2 1 2 2 0.227 0.00512 0.148 0.358 0.163 6.48
3 1 3 3 0.0980 0.00490 0.209 0.223 0.191 5.21
4 1 4 4 0.132 0.00488 0.199 0.257 0.185 6.59
5 1 5 5 0.209 0.00505 0.138 0.339 0.171 6.97
6 1 6 6 0.243 0.00436 0.146 0.355 0.164 6.47
7 1 7 7 0.154 0.00561 0.163 0.298 0.180 8.06
8 1 8 8 0.187 0.00555 0.151 0.329 0.172 7.72
9 1 9 9 0.234 0.00435 0.141 0.345 0.168 6.75
10 1 10 10 0.118 0.00445 0.208 0.232 0.189 5.51
# ℹ 9,990 more rows
# ℹ 6 more variables: accept_stat__ <dbl>, treedepth__ <dbl>, stepsize__ <dbl>,
# divergent__ <dbl>, n_leapfrog__ <dbl>, energy__ <dbl>
fm |> tidy_draws() |> summarise_draws()

# A tibble: 12 × 10
variable mean median sd mad q5 q95 rhat ess_bulk
<chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
1 b_Interc… 0.174 0.173 0.0592 0.0583 0.0765 0.271 1.00 7701.
2 b_Density 0.00507 0.00507 0.00178 0.00174 0.00217 0.00796 1.00 9581.
3 sigma 0.173 0.170 0.0296 0.0279 0.132 0.227 1.00 5483.
4 Intercept 0.303 0.303 0.0370 0.0356 0.242 0.364 1.00 6390.
5 lprior 0.175 0.177 0.0127 0.0116 0.153 0.193 1.00 6495.
6 lp__ 6.55 6.88 1.29 1.07 3.97 7.95 1.00 4021.
7 accept_s… 0.928 0.962 0.0921 0.0529 0.741 1 1.00 12439.
8 treedept… 2.39 2 0.555 0 2 3 1.00 8879.
9 stepsize… 0.571 0.568 0.0145 0.0142 0.554 0.593 Inf 4.01
10 divergen… 0 0 0 0 0 0 NA NA
11 n_leapfr… 5.47 7 2.01 0 3 7 1.00 8805.
12 energy__ -5.05 -5.39 1.77 1.62 -7.30 -1.67 1.00 3791.
# ℹ 1 more variable: ess_tail <dbl>
Many of the tidybayes functions require the posterior to be in a slightly different format, that of an rvar. An rvar is a compact way to store a distribution of values. We can use spread_rvars() to extract only a few of the columns. We then pipe that output to median_hdi() to get the median and 89% HDI of the posterior for each.
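As a quick illustration of the rvar format, here is a sketch using the posterior package (which tidybayes builds on); the draws are simulated stand-ins, not draws from our model:

```r
library(posterior)

set.seed(1)
draws <- rnorm(4000, mean = 0.17, sd = 0.06)  # stand-in for 4000 posterior draws

x <- rvar(draws)  # a single "random variable" backed by all 4000 draws
x                 # prints a compact mean +/- sd summary
mean(x)           # point summary averaged over the draws
```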
fm |>
  spread_rvars(b_Intercept, b_Density, sigma)

# A tibble: 1 × 3
b_Intercept b_Density sigma
<rvar[1d]> <rvar[1d]> <rvar[1d]>
1 0.17 ± 0.059 0.0051 ± 0.0018 0.17 ± 0.03
fm |>
spread_rvars(b_Intercept, b_Density, sigma) |>
  median_hdi(.width = 0.89)

# A tibble: 1 × 12
b_Intercept b_Intercept.lower b_Intercept.upper b_Density b_Density.lower
<dbl> <dbl> <dbl> <dbl> <dbl>
1 0.173 0.0843 0.272 0.00507 0.00229
# ℹ 7 more variables: b_Density.upper <dbl>, sigma <dbl>, sigma.lower <dbl>,
# sigma.upper <dbl>, .width <dbl>, .point <chr>, .interval <chr>
There are many options in tidybayes to plot distributions and intervals. Here is a point + interval plot of the three main parameters.
fm |>
spread_rvars(b_Intercept, b_Density, sigma) |>
pivot_longer(cols = everything()) |>
ggplot(aes(y = name, dist = value)) +
  stat_pointinterval(.width = c(0.89, 0.97))

b_Density is very small relative to the other parameters, so its variation looks really small in comparison. We might just plot it separately:
fm |>
spread_rvars(b_Density) |>
pivot_longer(cols = everything()) |>
ggplot(aes(y = name, dist = value)) +
  stat_pointinterval(.width = c(0.89, 0.97))

From this plot you can see how the posterior is credibly different from 0, even though the parameter estimate is small.
tidybayes works well with data_grid() from the modelr package. Like crossing() that we have used before, data_grid() generates the pairwise combinations of variables that are passed to it, but without needing to include as many details (it will by default use the range of continuous variables and all the levels of factors).
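Here is a minimal sketch of that behavior (the toy tibble `d` and its columns are made up for illustration; assumes modelr and tibble are available):

```r
library(modelr)
library(tibble)

d <- tibble(x = c(1, 5, 9),
            g = factor(c("a", "b", "a")))

# data_grid() crosses the requested values, using the data for defaults:
# seq_range() spans the observed range of x, and a bare factor name
# expands to its levels
data_grid(d, x = seq_range(x, n = 5), g)
# 5 values of x spanning 1..9 crossed with 2 levels of g: 10 rows
```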
If we then pipe that out to add_epred_draws() called with the fitted model, we can create a tibble with the values of Density across a range paired with the expected value of Proportion_forceps.
The second block of code pipes these values to ggplot() to make a plot of the observed data along with ribbons representing the 50%, 89%, and 97% HDIs for the expected values. Remember that these values do not include the standard deviation, so they are relatively narrow.
library(modelr)
# Expected parameter estimates
EW |>
data_grid(Density = seq_range(Density, n = 200)) |>
  add_epred_draws(fm)

# A tibble: 2,000,000 × 6
# Groups: Density, .row [200]
Density .row .chain .iteration .draw .epred
<dbl> <int> <int> <int> <int> <dbl>
1 0.152 1 NA NA 1 0.128
2 0.152 1 NA NA 2 0.228
3 0.152 1 NA NA 3 0.0987
4 0.152 1 NA NA 4 0.132
5 0.152 1 NA NA 5 0.210
6 0.152 1 NA NA 6 0.244
7 0.152 1 NA NA 7 0.155
8 0.152 1 NA NA 8 0.188
9 0.152 1 NA NA 9 0.234
10 0.152 1 NA NA 10 0.119
# ℹ 1,999,990 more rows
EW |>
data_grid(Density = seq_range(Density, n = 200)) |>
add_epred_draws(fm) |>
ggplot(aes(x = Density, y = Proportion_forceps)) +
stat_lineribbon(aes(y = .epred),
.width = c(0.5, 0.89, 0.97),
alpha = 0.5) +
  geom_point(data = EW)

We can do the same but generate a posterior predictive distribution plot by calling add_predicted_draws() instead (note that the variable is .prediction rather than .epred).
# Posterior predictive distribution
EW |>
data_grid(Density = seq_range(Density, n = 200)) |>
add_predicted_draws(fm) |>
ggplot(aes(x = Density, y = Proportion_forceps)) +
stat_lineribbon(aes(y = .prediction),
.width = c(0.5, 0.89, 0.97),
alpha = 0.5) +
  geom_point(data = EW)

Almost all of the points fall within the 97% interval, just like we would predict. Observe that the lines and edges are pretty rough. We could sample more iterations to smooth those out.
In the analyses below, try to add the packages above to your now pretty well-developed Bayesian modeling routines. Also see if you can work with the mcmc_ functions from bayesplot and pp_check() for plotting prior/posterior predictive checks.
These are three models that you saw in Quantitative Methods 1. We will leave many of the details of the analysis to you, providing some guidance for three challenging kinds of models to fit and interpret.
ANOVA-like
The data in Heart_Transplants.csv has data on the Survival time (in days) for heart transplant patients with varying degrees of Mismatch between the donor and recipient. You will need to convert Mismatch to a factor and get the factor in the correct order: low, medium, high. Low indicates a relatively good match and high a poor match. The data have a pronounced right skew.
Load the data, visualize, and transform how you see fit.
# FIXME
HT <- read_csv("../Data/Heart_Transplants.csv", show_col_types = FALSE) |>
mutate(Mismatch = fct_inorder(Mismatch))
ggplot(HT, aes(x = Mismatch, y = Survival)) +
geom_point(position = position_jitter(width = 0.1, seed = 4564356)) +
stat_summary(fun = mean, geom = "point", size = 3, color = "red") +
stat_summary(fun.data = mean_se, geom = "errorbar",
width = 0.1,
linewidth = 1,
               color = "red")

HT |>
group_by(Mismatch) |>
summarize(mean_Survival = mean(Survival),
            sd_Survival = sd(Survival))

# A tibble: 3 × 3
Mismatch mean_Survival sd_Survival
<fct> <dbl> <dbl>
1 Low 311. 431.
2 Medium 269 339.
3 High 71.9 86.0
HT <- HT |>
mutate(logSurvival = log(Survival))
ggplot(HT, aes(x = Mismatch, y = logSurvival)) +
geom_point(position = position_jitter(width = 0.1, seed = 4564356)) +
stat_summary(fun = mean, geom = "point", size = 3, color = "red") +
stat_summary(fun.data = mean_se, geom = "errorbar",
width = 0.1,
linewidth = 1,
               color = "red")

Model specification
\[\begin{align} \mathrm{logSurvival} & \sim Normal(\mu, \sigma) \\ \mu & = b[\mathrm{Mismatch}] \\ \end{align}\]
Prior specification and prior predictive check
# FIXME
get_prior(logSurvival ~ Mismatch - 1,
          data = HT)

                prior class           coef group resp dpar nlpar lb ub       source
               (flat)     b                                                 default
               (flat)     b   MismatchHigh                             (vectorized)
               (flat)     b    MismatchLow                             (vectorized)
               (flat)     b MismatchMedium                             (vectorized)
 student_t(3, 0, 2.5) sigma                                       0         default
PP <- brm(logSurvival ~ Mismatch - 1,
data = HT,
prior = c(prior(normal(4, 3), class = b)),
refresh = 0,
          sample_prior = "only")

Start sampling
Running MCMC with 4 sequential chains...
Chain 1 finished in 0.0 seconds.
Chain 2 finished in 0.0 seconds.
Chain 3 finished in 0.0 seconds.
Chain 4 finished in 0.0 seconds.
All 4 chains finished successfully.
Mean chain execution time: 0.0 seconds.
Total execution time: 0.6 seconds.
pp_check(PP, type = "stat_grouped", group = "Mismatch",
stat = "mean",
ndraws = 500,
         binwidth = 0.5)

Final model specification
\[\begin{align} \mathrm{logSurvival} & \sim Normal(\mu, \sigma) \\ \mu & = b[\mathrm{Mismatch}] \\ b[\mathrm{Mismatch}] & \sim Normal(0, 3) \\ \sigma & \sim HalfNormal(0, 2) \end{align}\]
Sampling
# FIXME
fm <- brm(logSurvival ~ Mismatch - 1,
data = HT,
prior = c(prior(normal(4, 3), class = b),
prior(normal(0, 3), class = sigma)),
refresh = 0,
          iter = 5e3)

Start sampling
Running MCMC with 4 sequential chains...
Chain 1 finished in 0.1 seconds.
Chain 2 finished in 0.1 seconds.
Chain 3 finished in 0.1 seconds.
Chain 4 finished in 0.0 seconds.
All 4 chains finished successfully.
Mean chain execution time: 0.1 seconds.
Total execution time: 0.6 seconds.
prior_summary(fm)

        prior class           coef group resp dpar nlpar lb ub       source
normal(4, 3) b user
normal(4, 3) b MismatchHigh (vectorized)
normal(4, 3) b MismatchLow (vectorized)
normal(4, 3) b MismatchMedium (vectorized)
normal(0, 3) sigma 0 user
Diagnostics
# FIXME
powerscale_plot_dens(
powerscale_sequence(fm),
quantities = c("mean", "sd"),
  variables = c("b_MismatchLow", "b_MismatchMedium", "b_MismatchHigh"))

powerscale_plot_quantities(
powerscale_sequence(fm),
quantities = c("mean", "sd"),
variables = c("b_MismatchLow", "b_MismatchMedium", "b_MismatchHigh")
)

powerscale_sensitivity(fm)

Sensitivity based on cjs_dist:
# A tibble: 4 × 4
variable prior likelihood diagnosis
<chr> <dbl> <dbl> <chr>
1 b_MismatchLow 0.00471 0.0913 -
2 b_MismatchMedium 0.00725 0.0839 -
3 b_MismatchHigh 0.00431 0.0846 -
4 sigma 0.00907 0.186 -
# FIXME
summary(fm)

 Family: gaussian
Links: mu = identity; sigma = identity
Formula: logSurvival ~ Mismatch - 1
Data: HT (Number of observations: 39)
Draws: 4 chains, each with iter = 5000; warmup = 2500; thin = 1;
total post-warmup draws = 10000
Regression Coefficients:
Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
MismatchLow 4.49 0.43 3.64 5.33 1.00 11083 6953
MismatchMedium 4.78 0.45 3.89 5.67 1.00 11549 7181
MismatchHigh 3.73 0.47 2.80 4.64 1.00 10750 7304
Further Distributional Parameters:
Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
sigma 1.61 0.20 1.27 2.05 1.00 9351 7130
Draws were sampled using sample(hmc). For each parameter, Bulk_ESS
and Tail_ESS are effective sample size measures, and Rhat is the potential
scale reduction factor on split chains (at convergence, Rhat = 1).
# FIXME
mcmc_trace(fm)

mcmc_rank_overlay(fm)

Posterior predictive simulation
# FIXME
pp_check(fm, type = "stat_grouped", group = "Mismatch",
stat = "mean",
ndraws = 500,
         binwidth = 0.1)

Summarizing the posterior
# FIXME
# Median HDI
fm |>
spread_rvars(b_MismatchLow, b_MismatchMedium, b_MismatchHigh) |>
set_names(distinct(HT, Mismatch) |> pull()) |>
  median_hdi(.width = 0.89)

# A tibble: 1 × 12
Low Low.lower Low.upper Medium Medium.lower Medium.upper High High.lower
<dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
1 4.49 3.86 5.21 4.78 4.09 5.52 3.73 2.98
# ℹ 4 more variables: High.upper <dbl>, .width <dbl>, .point <chr>,
# .interval <chr>
# Lots of different options for visualizing
fm |>
spread_rvars(b_MismatchLow, b_MismatchMedium, b_MismatchHigh) |>
set_names(distinct(HT, Mismatch) |> pull()) |>
pivot_longer(cols = everything()) |>
mutate(name = fct_inorder(name)) |>
ggplot(aes(y = name, dist = value)) +
  stat_pointinterval(.width = c(0.89, 0.97))

HT |>
data_grid(Mismatch) |>
add_predicted_draws(fm) |>
ggplot(aes(x = .prediction, y = Mismatch)) +
  stat_slab()

HT |>
data_grid(Mismatch) |>
add_predicted_draws(fm) |>
ggplot(aes(x = .prediction, y = Mismatch)) +
stat_interval(.width = c(0.50, 0.89, 0.97)) +
geom_point(aes(x = logSurvival), data = HT) +
  scale_color_brewer()

# Kruschke plot
library(distributional)
HT |>
data_grid(Mismatch) |>
add_epred_draws(fm, dpar = c("mu", "sigma")) |>
sample_draws(30) |>
ggplot(aes(y = Mismatch)) +
stat_slab(aes(xdist = dist_normal(mu = mu, sigma = sigma)),
slab_color = "gray65", alpha = 0.1, fill = NA) +
geom_point(aes(x = logSurvival), data = HT, shape = 21,
             fill = "#9ECAE1", size = 3)

Test the hypothesis that Medium and High mismatch differ from Low using contrasts.
fm |>
spread_rvars(b_MismatchLow, b_MismatchMedium, b_MismatchHigh) |>
set_names(distinct(HT, Mismatch) |> pull()) |>
mutate(Med_v_Low = Medium - Low,
High_v_Low = High - Low,
.keep = "none") |>
pivot_longer(cols = everything()) |>
ggplot(aes(y = name, dist = value)) +
  stat_pointinterval(.width = c(0.5, 0.89))

2x2 factorial design
The file Bird_Plasma.xlsx contains factorial data on blood plasma calcium concentration (Calcium, in mg Ca per 100 mL plasma) in male and female birds (Sex) each of which was treated or not with a hormone (Treatment).
- Load the data, and convert hormone and sex to factors.
- The levels of Treatment are "Hormone" and "None". Relevel Treatment so that "None" is the base level.
- Plot a reaction norm of Calcium vs. Sex, with color encoding Treatment to get a sense for the pattern.
# FIXME
BP <- readxl::read_excel("../Data/Bird_Plasma.xlsx") |>
mutate(Treatment = factor(Treatment),
Sex = factor(Sex),
Treatment = fct_relevel(Treatment, "None"))
BP |> count(Treatment, Sex)

# A tibble: 4 × 3
Treatment Sex n
<fct> <fct> <int>
1 None Female 5
2 None Male 5
3 Hormone Female 5
4 Hormone Male 5
ggplot(BP, aes(x = Sex,
y = Calcium,
color = Treatment,
group = Treatment)) +
geom_point(position = position_jitter(width = 0.05, seed = 474577),
size = 3) +
stat_summary(fun = mean, geom = "point", pch = 5, size = 5) +
stat_summary(fun = mean, geom = "line") +
  scale_color_manual(values = c("gray50", "darkgreen"))

Model specification
We have a factorial model, so we would like to model the two main effects (Sex and Treatment) as well as the Sex by Treatment interaction term. Interactions between categorical variables are complicated to code in Bayesian models. Although you can enter the model just as you would with lm() (Sex * Treatment), specifying the priors can be tricky, as can sorting out the posteriors afterward.
One approach that works well in some (most? all?) situations is to create a new composite variable that combines the two other variables. Thus the four factorial groups become a single factor with four levels (Female-Hormone, Female-None, Male-Hormone, and Male-None). Because we are testing hypotheses using contrasts (subtracting posterior distributions), we don’t have to worry about the usual main effects and interaction P-value based hypothesis tests.
You can do this with a simple mutate, joining the two variables:
# FIXME
BP <- BP |>
  mutate(Sex_Trt = paste(Sex, Treatment, sep = "_"))

One additional advantage of this approach is that you only need to specify a single prior for all the groups.
\[\begin{align} \mathrm{Calcium} & \sim Normal(\mu, \sigma) \\ \mu & = b[\mathrm{Sex\_Trt}] \\ \end{align}\]
Prior specification and prior predictive check
There are only 5 points per group, so the prior is potentially very powerful relative to the likelihood.
# FIXME
get_prior(Calcium ~ Sex_Trt - 1,
          data = BP)

                 prior class                  coef group resp dpar nlpar lb ub       source
                (flat)     b                                                        default
                (flat)     b Sex_TrtFemale_Hormone                             (vectorized)
                (flat)     b    Sex_TrtFemale_None                             (vectorized)
                (flat)     b   Sex_TrtMale_Hormone                             (vectorized)
                (flat)     b      Sex_TrtMale_None                             (vectorized)
 student_t(3, 0, 12.1) sigma                                              0         default
PP <- brm(Calcium ~ Sex_Trt - 1,
data = BP,
prior = c(prior(normal(20, 15), class = b),
prior(normal(0, 10), class = sigma)),
refresh = 0,
          sample_prior = "only")

Start sampling
Running MCMC with 4 sequential chains...
Chain 1 finished in 0.0 seconds.
Chain 2 finished in 0.0 seconds.
Chain 3 finished in 0.0 seconds.
Chain 4 finished in 0.0 seconds.
All 4 chains finished successfully.
Mean chain execution time: 0.0 seconds.
Total execution time: 0.6 seconds.
# Check the minimum predicted vs. observed
pp_check(PP, type = "stat_grouped", group = "Sex_Trt",
stat = "min",
ndraws = 500,
         binwidth = 2)

# Check the mean predicted vs. observed
pp_check(PP, type = "stat_grouped", group = "Sex_Trt",
stat = "mean",
ndraws = 500,
         binwidth = 2)

# Check the maximum predicted vs. observed
pp_check(PP, type = "stat_grouped", group = "Sex_Trt",
stat = "max",
ndraws = 500,
         binwidth = 2)

Final model specification
\[\begin{align} \mathrm{Calcium} & \sim Normal(\mu, \sigma) \\ \mu & = b[\mathrm{Sex\_Trt}] \\ b[\mathrm{Sex\_Trt}] & \sim Normal(20, 15) \\ \sigma & \sim HalfNormal(0, 10) \end{align}\]
Sampling
# FIXME
fm <- brm(Calcium ~ Sex_Trt - 1,
data = BP,
prior = c(prior(normal(20, 15), class = b),
prior(normal(0, 10), class = sigma)),
refresh = 0,
          iter = 5e3)

Start sampling
Running MCMC with 4 sequential chains...
Chain 1 finished in 0.1 seconds.
Chain 2 finished in 0.1 seconds.
Chain 3 finished in 0.1 seconds.
Chain 4 finished in 0.1 seconds.
All 4 chains finished successfully.
Mean chain execution time: 0.1 seconds.
Total execution time: 0.6 seconds.
prior_summary(fm)

          prior class                  coef group resp dpar nlpar lb ub       source
 normal(20, 15)     b                                                           user
 normal(20, 15)     b Sex_TrtFemale_Hormone                             (vectorized)
 normal(20, 15)     b    Sex_TrtFemale_None                             (vectorized)
 normal(20, 15)     b   Sex_TrtMale_Hormone                             (vectorized)
 normal(20, 15)     b      Sex_TrtMale_None                             (vectorized)
  normal(0, 10) sigma                                              0            user
Diagnostics
# FIXME
powerscale_plot_dens(
  powerscale_sequence(fm),
  quantities = c("mean", "sd"),
  variables = c("b_Sex_TrtFemale_Hormone", "b_Sex_TrtFemale_None"))

powerscale_plot_dens(
  powerscale_sequence(fm),
  quantities = c("mean", "sd"),
  variables = c("b_Sex_TrtMale_Hormone", "b_Sex_TrtMale_None"))

powerscale_plot_dens(
  powerscale_sequence(fm),
  quantities = c("mean", "sd"),
  variables = c("sigma"))

powerscale_plot_quantities(
  powerscale_sequence(fm),
  quantities = c("mean", "sd"),
  variables = c("b_Sex_TrtFemale_Hormone", "b_Sex_TrtFemale_None"))

powerscale_plot_quantities(
  powerscale_sequence(fm),
  quantities = c("mean", "sd"),
  variables = c("b_Sex_TrtMale_Hormone", "b_Sex_TrtMale_None"))

powerscale_plot_quantities(
  powerscale_sequence(fm),
  quantities = c("mean", "sd"),
  variables = c("sigma"))

powerscale_sensitivity(fm)

Sensitivity based on cjs_dist:
# A tibble: 5 × 4
variable prior likelihood diagnosis
<chr> <dbl> <dbl> <chr>
1 b_Sex_TrtFemale_Hormone 0.0180 0.137 -
2 b_Sex_TrtFemale_None 0.00895 0.112 -
3 b_Sex_TrtMale_Hormone 0.0112 0.116 -
4 b_Sex_TrtMale_None 0.0111 0.109 -
5 sigma 0.0120 0.319 -
# FIXME
summary(fm)

 Family: gaussian
Links: mu = identity; sigma = identity
Formula: Calcium ~ Sex_Trt - 1
Data: BP (Number of observations: 20)
Draws: 4 chains, each with iter = 5000; warmup = 2500; thin = 1;
total post-warmup draws = 10000
Regression Coefficients:
Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS
Sex_TrtFemale_Hormone 32.27 2.16 27.94 36.48 1.00 11466
Sex_TrtFemale_None 15.00 2.10 10.92 19.18 1.00 11147
Sex_TrtMale_Hormone 27.65 2.12 23.46 31.83 1.00 11634
Sex_TrtMale_None 12.28 2.10 8.18 16.51 1.00 12866
Tail_ESS
Sex_TrtFemale_Hormone 6627
Sex_TrtFemale_None 6999
Sex_TrtMale_Hormone 6768
Sex_TrtMale_None 6977
Further Distributional Parameters:
Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
sigma 4.68 0.91 3.31 6.83 1.00 7581 7168
Draws were sampled using sample(hmc). For each parameter, Bulk_ESS
and Tail_ESS are effective sample size measures, and Rhat is the potential
scale reduction factor on split chains (at convergence, Rhat = 1).
BP |>
data_grid(Sex_Trt) |>
add_predicted_draws(fm) |>
ggplot(aes(x = .prediction, y = Sex_Trt)) +
  stat_slab(alpha = 0.5, fill = "firebrick4")

BP |>
  data_grid(Sex_Trt) |>
  add_predicted_draws(fm) |>
  median_hdi(.width = 0.89)
Posterior predictive simulation
# FIXME
# Check the minimum predicted vs. observed
pp_check(fm, type = "stat_grouped", group = "Sex_Trt",
stat = "min",
ndraws = 500,
         binwidth = 1)

# Check the mean predicted vs. observed
pp_check(fm, type = "stat_grouped", group = "Sex_Trt",
stat = "mean",
ndraws = 500,
         binwidth = 1)

# Check the maximum predicted vs. observed
pp_check(fm, type = "stat_grouped", group = "Sex_Trt",
stat = "max",
ndraws = 500,
         binwidth = 1)

Summarizing the posterior
Compare the means of Hormone vs. Control separately by sex.
# FIXME
post <- fm |>
spread_rvars(b_Sex_TrtFemale_Hormone, b_Sex_TrtFemale_None,
b_Sex_TrtMale_Hormone, b_Sex_TrtMale_None) |>
  mutate(`F: Horm. vs. C.` = b_Sex_TrtFemale_Hormone - b_Sex_TrtFemale_None,
         `M: Horm. vs. C.` = b_Sex_TrtMale_Hormone - b_Sex_TrtMale_None,
.keep = "none")
post |>
pivot_longer(cols = everything()) |>
ggplot(aes(xdist = value, fill = name)) +
stat_slab(alpha = 0.5) +
scale_fill_manual(values = c("darkslateblue", "coral"),
name = "Contrast") +
labs(x = "Difference (Hormone - Control)", y = "Density")
median_hdi(post, .width = 0.89)
# A tibble: 1 × 9
`F: Horm. vs. C.` `F: Horm. vs. C..lower` `F: Horm. vs. C..upper`
<dbl> <dbl> <dbl>
1 15.4 10.7 20.1
# ℹ 6 more variables: `M: Horm. vs. C.` <dbl>, `M: Horm. vs. C..lower` <dbl>,
# `M: Horm. vs. C..upper` <dbl>, .width <dbl>, .point <chr>, .interval <chr>
# FIXME
# Here's how you would do this analysis using the regular interactions
# coding with *.
# This generates the prior prediction and converts to the group posteriors
# The priors are really weak, because there is so little data to learn
# from.
PP <- brm(Calcium ~ Sex * Treatment,
data = BP,
prior = c(prior(normal(15, 20), class = Intercept),
prior(normal(0, 10), coef = SexMale),
prior(normal(0, 20), coef = TreatmentHormone),
prior(normal(0, 10), coef = SexMale:TreatmentHormone),
prior(normal(0, 15), class = sigma)),
refresh = 0,
sample_prior = "only") |>
spread_draws(b_Intercept, b_SexMale,
b_TreatmentHormone, `b_SexMale:TreatmentHormone`) |>
mutate(Female_None = b_Intercept,
Female_Hormone = b_Intercept + b_TreatmentHormone,
Male_None = b_Intercept + b_SexMale,
Male_Hormone = b_Intercept + b_SexMale + b_TreatmentHormone +
`b_SexMale:TreatmentHormone`,
.keep = "none")
Start sampling
Running MCMC with 4 sequential chains...
Chain 1 finished in 0.0 seconds.
Chain 2 finished in 0.0 seconds.
Chain 3 finished in 0.0 seconds.
Chain 4 finished in 0.0 seconds.
All 4 chains finished successfully.
Mean chain execution time: 0.0 seconds.
Total execution time: 0.6 seconds.
PP_long <- pivot_longer(PP, cols = everything(),
names_to = "Sex_Trt",
values_to = "Calcium") |>
separate(col = Sex_Trt, into = c("Sex", "Treatment"), sep = "_")
# See how the Female-None group has the lowest variance, and the Male-Hormone
# group has the highest variance
PP_long |>
group_by(Sex, Treatment) |>
summarize(mean_Calcium = mean(Calcium),
var_Calcium = var(Calcium))
`summarise()` has grouped output by 'Sex'. You can override using the `.groups`
argument.
# A tibble: 4 × 4
# Groups: Sex [2]
Sex Treatment mean_Calcium var_Calcium
<chr> <chr> <dbl> <dbl>
1 Female Hormone 15.8 516.
2 Female None 15.3 525.
3 Male Hormone 15.5 553.
4 Male None 15.3 511.
# A kind of prior predictive distribution plot
ggplot() +
geom_density(data = PP_long, aes(Calcium)) +
geom_point(data = BP, aes(x = Calcium, y = 0),
shape = 21, fill = "#9ECAE1", size = 3) +
facet_grid(Sex ~ Treatment)
# Fit the model
fm <- brm(Calcium ~ Sex * Treatment,
data = BP,
prior = c(prior(normal(15, 20), class = Intercept),
prior(normal(0, 10), coef = SexMale),
prior(normal(0, 20), coef = TreatmentHormone),
prior(normal(0, 10), coef = SexMale:TreatmentHormone),
prior(normal(0, 15), class = sigma)),
refresh = 0,
iter = 5e3)
Start sampling
Running MCMC with 4 sequential chains...
Chain 1 finished in 0.1 seconds.
Chain 2 finished in 0.1 seconds.
Chain 3 finished in 0.1 seconds.
Chain 4 finished in 0.1 seconds.
All 4 chains finished successfully.
Mean chain execution time: 0.1 seconds.
Total execution time: 0.6 seconds.
prior_summary(fm)
prior class coef group resp dpar nlpar lb ub source
(flat) b default
normal(0, 10) b SexMale user
normal(0, 10) b SexMale:TreatmentHormone user
normal(0, 20) b TreatmentHormone user
normal(15, 20) Intercept user
normal(0, 15) sigma 0 user
# These all look fine.
powerscale_plot_dens(
powerscale_sequence(fm),
quantities = c("mean", "sd"),
variables = c("b_Intercept", "b_SexMale"))
powerscale_plot_dens(
powerscale_sequence(fm),
quantities = c("mean", "sd"),
variables = c("b_TreatmentHormone", "b_SexMale:TreatmentHormone"))
powerscale_sensitivity(fm)
Sensitivity based on cjs_dist:
# A tibble: 6 × 4
variable prior likelihood diagnosis
<chr> <dbl> <dbl> <chr>
1 b_Intercept 0.0127 0.102 -
2 b_SexMale 0.0138 0.0930 -
3 b_TreatmentHormone 0.0247 0.113 -
4 b_SexMale:TreatmentHormone 0.0242 0.0892 -
5 sigma 0.0140 0.301 -
6 Intercept 0.00383 0.114 -
summary(fm)
Family: gaussian
Links: mu = identity; sigma = identity
Formula: Calcium ~ Sex * Treatment
Data: BP (Number of observations: 20)
Draws: 4 chains, each with iter = 5000; warmup = 2500; thin = 1;
total post-warmup draws = 10000
Regression Coefficients:
Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS
Intercept 14.97 2.03 10.96 18.97 1.00 5621
SexMale -2.83 2.74 -8.20 2.55 1.00 5038
TreatmentHormone 17.26 2.81 11.64 22.69 1.00 4996
SexMale:TreatmentHormone -1.57 3.79 -8.99 5.94 1.00 4559
Tail_ESS
Intercept 6186
SexMale 5395
TreatmentHormone 5308
SexMale:TreatmentHormone 5058
Further Distributional Parameters:
Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
sigma 4.66 0.90 3.31 6.83 1.00 6330 5799
Draws were sampled using sample(hmc). For each parameter, Bulk_ESS
and Tail_ESS are effective sample size measures, and Rhat is the potential
scale reduction factor on split chains (at convergence, Rhat = 1).
# Extract the posterior draws and compute the group means
post <- fm |>
spread_draws(b_Intercept, b_SexMale,
b_TreatmentHormone, `b_SexMale:TreatmentHormone`) |>
mutate(Female_None = b_Intercept,
Female_Hormone = b_Intercept + b_TreatmentHormone,
Male_None = b_Intercept + b_SexMale,
Male_Hormone = b_Intercept + b_SexMale + b_TreatmentHormone +
`b_SexMale:TreatmentHormone`,
.keep = "none")
# The Male-Hormone group retains the wider variance in the posterior, even
# though it does not in the observed data.
pivot_longer(post, cols = everything(),
names_to = "Sex_Trt",
values_to = "Calcium") |>
separate(col = Sex_Trt, into = c("Sex", "Treatment"), sep = "_") |>
group_by(Sex, Treatment) |>
summarize(mean_Calcium = mean(Calcium),
var_Calcium = var(Calcium))
`summarise()` has grouped output by 'Sex'. You can override using the `.groups`
argument.
# A tibble: 4 × 4
# Groups: Sex [2]
Sex Treatment mean_Calcium var_Calcium
<chr> <chr> <dbl> <dbl>
1 Female Hormone 32.2 4.29
2 Female None 15.0 4.12
3 Male Hormone 10.6 11.6
4 Male None 12.1 4.14
# Roughly equal variances in the observed data
BP |>
group_by(Sex, Treatment) |>
summarize(mean_Calcium = mean(Calcium),
var_Calcium = var(Calcium))
`summarise()` has grouped output by 'Sex'. You can override using the `.groups`
argument.
# A tibble: 4 × 4
# Groups: Sex [2]
Sex Treatment mean_Calcium var_Calcium
<fct> <fct> <dbl> <dbl>
1 Female None 14.9 17.1
2 Female Hormone 32.5 21.8
3 Male None 12.1 18.0
4 Male Hormone 27.8 18.4
median_hdi(post, .width = 0.89)
Female_None Female_None.lower Female_None.upper Female_Hormone
1 14.97165 11.9024 18.2995 32.2491
Female_Hormone.lower Female_Hormone.upper Male_None Male_None.lower
1 29.024 35.576 12.09723 8.88245
Male_None.upper Male_Hormone Male_Hormone.lower Male_Hormone.upper .width
1 15.34086 10.59082 5.02043 15.86541 0.89
.point .interval
1 median hdi
Multiple continuous predictors
Working with multiple continuous predictors also poses some unique challenges (not to mention continuous predictors with interactions). Visualization in particular is not straightforward, because, unless you want a 3D plot, you can’t plot three continuous variables (1 outcome + 2 predictors) simultaneously. Options include making separate plots, putting one predictor on the x-axis and coloring points by the other, or choosing specific values of one predictor at which to visualize the data. Usually you do these reciprocally for the two predictors.
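As a minimal sketch of the color-by-the-other-predictor option (using a small simulated data frame d with outcome y and predictors x1 and x2; all names here are hypothetical, not from the milk example below):

```r
library(ggplot2)

set.seed(1)

# Simulated data: two negatively correlated continuous predictors, one outcome
d <- data.frame(x1 = runif(50, 0, 10))
d$x2 <- 10 - d$x1 + rnorm(50)
d$y <- 0.5 * d$x1 - 0.3 * d$x2 + rnorm(50)

# Put one predictor on the x-axis and map the other to color
ggplot(d, aes(x = x1, y = y, color = x2)) +
  geom_point(size = 3) +
  scale_color_viridis_c(name = "x2") +
  labs(x = "x1", y = "Outcome")
```

Swapping x1 and x2 gives the reciprocal view.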
To work through this example, we will use the (apparent) trade-off between fat content and lactose content in mammal milk. We used this example in Quantitative Methods 1 to show how multiple regression is actually working.
Load the data in Milk.xlsx, select the columns kcal.per.g, perc.fat, and perc.lactose, and rename them to Milk_energy, Fat, and Lactose, respectively. We will predict the first by the additive effects of the latter two.
There are some missing values in the data, so drop any rows with NA. These are comparative data for different species of primates, but we will ignore those relationships for this analysis.
# FIXME
#| warning: false
MM <- readxl::read_excel("../Data/Milk.xlsx") |>
select(kcal.per.g, perc.fat, perc.lactose) |>
drop_na() |>
rename(Milk_energy = kcal.per.g,
Fat = perc.fat,
Lactose = perc.lactose)
Make two plots, one where energy is predicted by fat and the other by lactose.
# FIXME
p1 <- plot_grid(ggplot(MM, aes(Fat, Milk_energy)) + geom_point(),
ggplot(MM, aes(Lactose, Milk_energy)) + geom_point(),
ncol = 2)
p1
You will see that they vary inversely: as fat goes up, lactose goes down, because milk’s composition is constrained to a finite percentage (100%). The third component, protein (mostly casein), makes up the remainder; we are ignoring protein in this analysis.
If you check the correlation between fat and lactose, you will see it’s large (\(r \approx\) -0.94). In a frequentist regression, you might be worried about multicollinearity in this case.
# FIXME
cor(MM$Fat, MM$Lactose)
[1] -0.9416373
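For comparison with the frequentist workflow mentioned above, a common multicollinearity check is the variance inflation factor. A sketch, assuming the car package is installed (this check is not part of the Bayesian analysis below):

```r
library(car)

# Fit the frequentist analogue of the model and compute VIFs.
# With two predictors, VIF is roughly 1 / (1 - r^2); with r ~ -0.94
# that is near 9, close to the common rule-of-thumb concern of VIF > 10.
fm_lm <- lm(Milk_energy ~ Fat + Lactose, data = MM)
vif(fm_lm)
```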
Model specification
\[\begin{align} \mathrm{Milk\_energy} & \sim Normal(\mu, \sigma) \\ \mu & = b_0 + b_1 \mathrm{Fat} + b_2 \mathrm{Lactose} \\ \end{align}\]
Prior specification and prior predictive check
# FIXME
PP <- brm(Milk_energy ~ Fat + Lactose,
data = MM,
prior = c(prior(normal(0, 0.05), class = b),
prior(normal(0, 5), class = Intercept),
prior(normal(0, 5), class = sigma)),
sample_prior = "only",
refresh = 0)
Start sampling
Running MCMC with 4 sequential chains...
Chain 1 finished in 0.0 seconds.
Chain 2 finished in 0.0 seconds.
Chain 3 finished in 0.0 seconds.
Chain 4 finished in 0.0 seconds.
All 4 chains finished successfully.
Mean chain execution time: 0.0 seconds.
Total execution time: 0.7 seconds.
pp_check(PP, ndraws = 100)
pp_check(PP, type = "stat", stat = "median", binwidth = 1)
Using all posterior draws for ppc type 'stat' by default.
Final model specification
\[\begin{align} \mathrm{Milk\_energy} & \sim Normal(\mu, \sigma) \\ \mu & = b_0 + b_1 \mathrm{Fat} + b_2 \mathrm{Lactose} \\ b_0 & \sim Normal(0, 5) \\ b_1 & \sim Normal(0, 0.05) \\ b_2 & \sim Normal(0, 0.05) \\ \sigma & \sim HalfNormal(0, 5) \end{align}\]
Sampling
# FIXME
fm <- brm(Milk_energy ~ Fat + Lactose,
data = MM,
prior = c(prior(normal(0, 0.05), class = b),
prior(normal(0, 5), class = Intercept),
prior(normal(0, 5), class = sigma)),
refresh = 0,
iter = 5e3)
Start sampling
Running MCMC with 4 sequential chains...
Chain 1 finished in 0.1 seconds.
Chain 2 finished in 0.1 seconds.
Chain 3 finished in 0.2 seconds.
Chain 4 finished in 0.2 seconds.
All 4 chains finished successfully.
Mean chain execution time: 0.2 seconds.
Total execution time: 1.1 seconds.
prior_summary(fm)
prior class coef group resp dpar nlpar lb ub source
normal(0, 0.05) b user
normal(0, 0.05) b Fat (vectorized)
normal(0, 0.05) b Lactose (vectorized)
normal(0, 5) Intercept user
normal(0, 5) sigma 0 user
Diagnostics
# FIXME
mcmc_combo(fm, regex_pars = "^b")
mcmc_rank_overlay(fm, regex_pars = "^b")
powerscale_plot_dens(
powerscale_sequence(fm),
quantities = c("mean", "sd"),
variables = "b_Intercept")
powerscale_plot_dens(
powerscale_sequence(fm),
quantities = c("mean", "sd"),
variables = c("b_Fat", "b_Lactose"))
powerscale_plot_quantities(
powerscale_sequence(fm),
quantities = c("mean", "sd"),
variables = c("b_Intercept", "b_Fat", "b_Lactose"))
powerscale_sensitivity(fm)
Sensitivity based on cjs_dist:
# A tibble: 5 × 4
variable prior likelihood diagnosis
<chr> <dbl> <dbl> <chr>
1 b_Intercept 0.00134 0.0889 -
2 b_Fat 0.00126 0.0935 -
3 b_Lactose 0.00136 0.0889 -
4 sigma 0.000247 0.221 -
5 Intercept 0.0000613 0.0903 -
Posterior predictive simulation
Use pp_check() to make a density plot of the observed data superimposed on draws from the posterior.
# FIXME
pp_check(fm, ndraws = 100)
To visualize the effect of the two continuous predictors, we’ll have to get creative. Here are the steps:
- Make a grid of observations for prediction. Make a sequence of 200 values between 3 and 56 for Fat. Specify only three values for Lactose: 30, 50, and 70. Each value of Fat will be associated with three levels of Lactose.
- Generate the posterior predictive distributions using posterior_epred() and the new data you just created.
- Calculate the median and 89% HDI intervals using mutate() like we did in the lecture slides.
# FIXME
pred_values <- crossing(
Fat = seq(3, 56, length.out = 200),
Lactose = c(30, 50, 70)
)
p_pred <- posterior_epred(fm, newdata = pred_values)
pred_values <- pred_values |>
mutate(Q50 = apply(p_pred, MARGIN = 2, FUN = quantile, prob = 0.5),
Q5.5 = apply(p_pred, MARGIN = 2, FUN = quantile, prob = 0.055),
Q94.5 = apply(p_pred, MARGIN = 2, FUN = quantile, prob = 0.945),
Lactose = factor(Lactose))
pred_values
# A tibble: 600 × 5
Fat Lactose Q50 Q5.5 Q94.5
<dbl> <fct> <dbl> <dbl> <dbl>
1 3 30 0.750 0.536 0.970
2 3 50 0.577 0.444 0.711
3 3 70 0.403 0.342 0.464
4 3.27 30 0.751 0.538 0.969
5 3.27 50 0.577 0.446 0.710
6 3.27 70 0.404 0.344 0.463
7 3.53 30 0.751 0.539 0.969
8 3.53 50 0.578 0.447 0.710
9 3.53 70 0.404 0.345 0.463
10 3.80 30 0.752 0.541 0.968
# ℹ 590 more rows
You should have a tibble of 600 rows × 5 columns, with columns for Fat, Lactose, Q50, Q5.5, and Q94.5.
Make a ribbon plot of the 89% interval, add a line for the median, and facet by Lactose in 3 columns. You should be able to see what the model predicts for milk energy as a function of fat at the three levels of lactose.
It will take some staring at this plot to make sense of it. Pay particular attention to the places where the model is pretty sure (narrow bands) or unsure (wide bands).
If you make a composite plot with the pair of scatterplots from the first chunk in this example in one row and this new plot in row 2, it might help to make sense of the output.
# FIXME
p2 <- ggplot() +
geom_ribbon(data = pred_values,
aes(x = Fat, ymin = Q5.5, ymax = Q94.5,
fill = Lactose), alpha = 0.25) +
geom_line(data = pred_values,
aes(x = Fat, y = Q50, color = Lactose)) +
facet_grid(. ~ Lactose) +
scale_color_viridis_d(option = "D") +
scale_fill_viridis_d(option = "D") +
labs(x = "Fat Percentage", y = "Milk Energy")
plot_grid(p1, p2, nrow = 2)
Summarizing the posterior
Summarize the posterior however you think is appropriate.
# FIXME
summary(fm) |> print(digits = 4)
Family: gaussian
Links: mu = identity; sigma = identity
Formula: Milk_energy ~ Fat + Lactose
Data: MM (Number of observations: 29)
Draws: 4 chains, each with iter = 5000; warmup = 2500; thin = 1;
total post-warmup draws = 10000
Regression Coefficients:
Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
Intercept 1.0058 0.2254 0.5598 1.4537 1.0003 4636 5146
Fat 0.0020 0.0027 -0.0034 0.0073 1.0006 4619 4970
Lactose -0.0087 0.0027 -0.0142 -0.0033 1.0004 4746 5338
Further Distributional Parameters:
Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
sigma 0.0677 0.0099 0.0514 0.0903 1.0005 5200 5598
Draws were sampled using sample(hmc). For each parameter, Bulk_ESS
and Tail_ESS are effective sample size measures, and Rhat is the potential
scale reduction factor on split chains (at convergence, Rhat = 1).
Footnotes
Kallioinen, N., T. Paananen, P.-C. Bürkner, and A. Vehtari. 2024. Detecting and diagnosing prior and likelihood sensitivity with power-scaling. Stat. Comput. 34.↩︎